
class DigitEmbedding:
    """Encode sequences of non-negative integers as flat, fixed-length digit lists.

    Each number is written as exactly ``num_digits`` base-10 digits, most
    significant digit first (zero-filled on the left). The per-number digit
    groups are concatenated and then right-padded with ``0.0`` so that an
    encoded sequence always has ``max_seq_len * num_digits`` entries.

    Padding slots and the number ``0`` are both all-zero groups, so the
    encoding is ambiguous for trailing zeros. ``decode`` strips only
    *trailing* all-zero groups by default; pass ``length`` to recover the
    original sequence exactly.
    """

    def __init__(self, num_digits, max_seq_len):
        """Create an encoder.

        Args:
            num_digits: Number of digits used for every number (must be >= 1).
            max_seq_len: Maximum number of numbers in a sequence (must be >= 0).

        Raises:
            ValueError: If either argument is out of range.
        """
        super().__init__()
        if num_digits < 1:
            raise ValueError(f"num_digits must be >= 1, got {num_digits!r}")
        if max_seq_len < 0:
            raise ValueError(f"max_seq_len must be >= 0, got {max_seq_len!r}")
        self.num_digits = num_digits  # fixed number of digits per encoded number
        self.max_seq_len = max_seq_len  # maximum number of numbers per sequence

    def digit_splitter(self, input_seq):
        """Split each number into ``num_digits`` digits, most significant first.

        Args:
            input_seq: Iterable of non-negative integers.

        Returns:
            A list with one ``num_digits``-long digit list per input number.

        Raises:
            ValueError: If a number is negative or needs more than
                ``num_digits`` digits. The original silently dropped the high
                digits, which corrupted the round trip.
        """
        limit = 10 ** self.num_digits
        output = []
        for num in input_seq:
            if num < 0 or num >= limit:
                raise ValueError(
                    f"{num!r} cannot be represented with {self.num_digits} digits"
                )
            digits = [0] * self.num_digits
            # Fill from the least significant position (rightmost) leftwards.
            for pos in range(self.num_digits - 1, -1, -1):
                num, digits[pos] = divmod(num, 10)
            output.append(digits)
        return output

    def digit_decoder(self, encoded_seq):
        """Rebuild numbers from digit groups (most significant digit first).

        Args:
            encoded_seq: Iterable of digit groups, each with at least
                ``num_digits`` entries. Only the first ``num_digits`` entries
                of a group are used.

        Returns:
            A list of decoded numbers, one per group.
        """
        output = []
        for digits in encoded_seq:
            num = 0
            for k in range(self.num_digits):
                num = num * 10 + digits[k]  # Horner's method
            output.append(num)
        return output

    def pad_sequence(self, sequence):
        """Flatten digit groups and pad the result to ``max_seq_len * num_digits``.

        Groups shorter than ``num_digits`` are zero-filled on the *left*.
        Digits are most significant first, so left-filling preserves each
        number's value.

        Args:
            sequence: List of digit lists.

        Returns:
            A flat list of length ``max_seq_len * num_digits``. Padding
            entries are ``0.0``.

        Raises:
            ValueError: If ``sequence`` has more than ``max_seq_len`` groups or
                a group has more than ``num_digits`` digits.
        """
        if len(sequence) > self.max_seq_len:
            raise ValueError(
                f"sequence has {len(sequence)} numbers, "
                f"more than max_seq_len={self.max_seq_len}"
            )
        padded_sequence = []
        for digits in sequence:
            if len(digits) > self.num_digits:
                raise ValueError(
                    f"digit group {digits!r} is longer than num_digits={self.num_digits}"
                )
            padded_sequence.extend([0.0] * (self.num_digits - len(digits)))
            padded_sequence.extend(digits)
        padded_sequence.extend(
            [0.0] * ((self.max_seq_len - len(sequence)) * self.num_digits)
        )
        return padded_sequence

    def encode(self, input_seq):
        """Encode a sequence of numbers as a flat, padded digit list.

        Args:
            input_seq: Iterable of non-negative integers, at most
                ``max_seq_len`` of them, each below ``10 ** num_digits``.

        Returns:
            A flat list of length ``max_seq_len * num_digits``.

        Raises:
            ValueError: Propagated from ``digit_splitter`` / ``pad_sequence``.
        """
        # Split the input sequence into per-number digit groups.
        digit_seq = self.digit_splitter(input_seq)
        # Flatten and pad to the fixed maximum length.
        return self.pad_sequence(digit_seq)

    def decode(self, encoded_seq, length=None):
        """Decode a flat digit list produced by ``encode``.

        Args:
            encoded_seq: Flat sequence whose length is a multiple of
                ``num_digits``.
            length: Optional number of encoded numbers. If given, exactly the
                first ``length`` groups are decoded, so trailing zeros are
                kept. If omitted, trailing all-zero groups are treated as
                padding and removed. Interior zeros are always kept.

        Returns:
            The list of decoded numbers.

        Raises:
            ValueError: If ``len(encoded_seq)`` is not a multiple of
                ``num_digits`` or ``length`` is negative.
        """
        if len(encoded_seq) % self.num_digits:
            raise ValueError(
                f"encoded length {len(encoded_seq)} is not a multiple of "
                f"num_digits={self.num_digits}"
            )
        # Reshape the flat sequence into num_digits-sized groups.
        groups = [
            encoded_seq[i:i + self.num_digits]
            for i in range(0, len(encoded_seq), self.num_digits)
        ]
        if length is not None:
            if length < 0:
                raise ValueError(f"length must be >= 0, got {length!r}")
            groups = groups[:length]
        else:
            # Padding is only ever appended at the end, so strip only trailing
            # all-zero groups. Zeros inside the sequence are real values.
            while groups and all(digit == 0.0 for digit in groups[-1]):
                groups.pop()
        return self.digit_decoder(groups)